
Fix save_hyperparameters not crashing on dataclass with init=False #21051


Open
wants to merge 8 commits into base: master

Conversation

SkafteNicki
Contributor

@SkafteNicki SkafteNicki commented Aug 11, 2025

What does this PR do?

Fixes #21036
Skip attributes where the user has set `init=False`.
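A minimal sketch of the filtering this fix applies when collecting init arguments from a dataclass instance (the `Config` class and its fields are made up for illustration, not the PR's actual code):

```python
import dataclasses

@dataclasses.dataclass
class Config:
    lr: float         # real hyperparameter, part of __init__
    steps: int = 100  # hyperparameter with a default
    # internal state, excluded from __init__ via init=False
    cache: dict = dataclasses.field(init=False, default_factory=dict)

cfg = Config(lr=0.1)

# Collect only the fields that participate in __init__, skipping init=False ones.
init_args = {f.name: getattr(cfg, f.name) for f in dataclasses.fields(cfg) if f.init}
print(init_args)  # {'lr': 0.1, 'steps': 100}
```

With this filter, `init=False` fields such as `cache` never reach the collected hyperparameters, so they cannot crash the collection step.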

Before submitting
  • Was this discussed/agreed via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

Reviewer checklist
  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

📚 Documentation preview 📚: https://pytorch-lightning--21051.org.readthedocs.build/en/21051/

@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Aug 11, 2025
@QuentinSoubeyranAqemia

QuentinSoubeyranAqemia commented Aug 11, 2025

Thank you!

I wonder if the warning is useful. save_hyperparameters is explicit that it picks up the argument from the class __init__, and field(init=False) quite explicitly excludes the attribute from the __init__.

I encountered #21036 by doing exactly what the warning suggests to do: by initializing the attributes in __post_init__ while using @dataclasses.dataclass.
The need for dataclasses.field(init=False) comes from type annotations: I want those __post_init__-initialized attributes (typically the nn.Module that make up the model) to be type-annotated (for better IDE integration/completion & type-checking). Since @dataclasses.dataclass would add any such attributes to the __init__ by default, I need to exclude them with init=False.

In that case, the warning would instruct the user to do exactly what they're already doing, which might be confusing and make them wonder what they should do exactly.

WDYT?
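The `field(init=False)` + `__post_init__` mechanics described above can be checked with plain dataclasses; a minimal sketch (the `Model` class and `weight_count` field are hypothetical names, not from the PR):

```python
import dataclasses
import inspect

@dataclasses.dataclass
class Model:
    in_dim: int
    out_dim: int
    # excluded from __init__; derived in __post_init__ from the hparams above
    weight_count: int = dataclasses.field(init=False)

    def __post_init__(self):
        self.weight_count = self.in_dim * self.out_dim

m = Model(in_dim=3, out_dim=4)
print(m.weight_count)  # 12

# The init=False field is absent from the generated __init__ signature.
sig = inspect.signature(Model.__init__)
print('weight_count' in sig.parameters)  # False
```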

obj_fields = fields(obj)
init_args = {f.name: getattr(obj, f.name) for f in obj_fields if f.init}
if any(not f.init for f in obj_fields):
    rank_zero_warn(


See discussion above about this warning

@SkafteNicki
Contributor Author


@QuentinSoubeyranAqemia what I was thinking users would need to do is something like this:

import dataclasses

import lightning.pytorch as L

@dataclasses.dataclass
class Module(L.LightningModule):
    a: float
    b: float
    c: float = 0.0

    def __post_init__(self):
        self.c = self.a + self.b
        self.save_hyperparameters()

model = Module(a=1, b=2)
print(model.hparams)

This is kind of a misuse of dataclasses, but it should still type-check and be correctly saved in self.hparams by using self.save_hyperparameters(). I really think the warning should be there, but maybe the wording should be different.

@Borda Borda changed the title Fix save_hyperparameters not crashing on dataclass with `init=False Fix save_hyperparameters not crashing on dataclass with init=False Aug 11, 2025
@QuentinSoubeyranAqemia

QuentinSoubeyranAqemia commented Aug 11, 2025

Perhaps the minimal reproducing example I provided is causing some confusion. In practice, the attributes excluded with field(init=False) are excluded precisely because they are not hyper-parameters -- otherwise they would be part of __init__.

The use-case is for internal nn.Modules that need to be initialized in __post_init__ once @dataclasses.dataclass has stored the hparams into attributes for us. Those nn.Module attributes must not be part of __init__ nor save_hyperparameters. They cannot be initialized with a field(default=...) or field(default_factory=...) because they depend on the hyper-parameters.

Here's an example which is hopefully clearer: transforming a (simplistic) vanilla lightning module to leverage @dataclasses.dataclass to write the boilerplate of storing hparams into attribute for us:

import lightning.pytorch as L
from torch import nn  # needed for nn.Linear below; MyModule is some user-defined nn.Module

class Module(L.LightningModule):
    input_dim: int
    hparam: int
    readout_dim: int

    model: MyModule
    readout: nn.Linear

    # we need to write a boilerplate signature again, though it is already laid out above
    def __init__(self, input_dim: int, hparam: int, readout_dim: int):
        super().__init__()
        self.save_hyperparameters()
        # boilerplate to store hyper-parameters
        self.input_dim = input_dim
        self.hparam = hparam
        self.readout_dim = readout_dim
        # create internals
        self.model = MyModule(self.input_dim, self.hparam)  # nn.Module
        self.readout = nn.Linear(..., self.readout_dim)

into the more concise

import dataclasses

import lightning.pytorch as L
from torch import nn  # needed for nn.Linear below; MyModule is some user-defined nn.Module

@dataclasses.dataclass(...)  # some specific flag needed here, not the point of this discussion
class Module(L.LightningModule):
    input_dim: int
    hparam: int
    readout_dim: int

    # internals, not hparams, exclude them from __init__ and save_hyperparameters()
    model: MyModule = dataclasses.field(init=False)
    readout: nn.Linear = dataclasses.field(init=False)

    def __post_init__(self):
        # no boilerplate!
        super().__init__()
        self.save_hyperparameters()
        self.model = MyModule(self.input_dim, self.hparam)  # nn.Module
        self.readout = nn.Linear(..., self.readout_dim)

@SkafteNicki
Contributor Author

@QuentinSoubeyranAqemia thanks for expanding on the use case; it all makes sense now. I have removed the warning, since you are right that it does not make sense here.

Labels
pl Generic label for PyTorch Lightning package

Successfully merging this pull request may close these issues.

LightningModule.save_hyperparameters crash on dataclass with non-init fields
2 participants